import baostock as bs
import pandas as pd
import psycopg2
from sqlalchemy import create_engine

from dao import stock_kdata_dao
from model.stock_kdata_model import StockKData
from system.sys_logger import Logfactory
from utils import uuid_util


# 构建StockKData集合
def build_stock_kdata_list(baostock_kdata_frame: pd.DataFrame):
    try:
        stock_kdata_list = []

        if baostock_kdata_frame is None or len(baostock_kdata_frame) == 0:
            return stock_kdata_list

        for index, row in baostock_kdata_frame.iterrows():
            stock_kdata = build_stock_kdata(row)
            if stock_kdata is not None:
                stock_kdata_list.append(stock_kdata)
        return stock_kdata_list
    except Exception as e:
        Logfactory.logger.error("build_stock_kdata_list exception and error message: %s", e)
        raise e


# 构建StockKData对象
def build_stock_kdata(row: pd.Series):
    try:
        if row is None:
            Logfactory.logger.warn("baostock_kdata is None")
            return None
        else:
            # 交易所行情日期
            date = row['date']
            # 证券代码 sh或sz.+6位数字代码，或者指数代码，如：sh.601398。sh：上海；sz：深圳。此参数不可为空
            code = row['code']
            # 开盘价
            open_price = float(row['open']) if row['open'] else 0.0
            # 最高价
            high_price = float(row['high']) if row['high'] else 0.0
            # 最低价
            low_price = float(row['low']) if row['low'] else 0.0
            # 收盘价
            close_price = float(row['close']) if row['close'] else 0.0
            # 前收盘价
            preclose_price = float(row['preclose']) if row['preclose'] else 0.0
            # 成交量（累计单位：股）
            volume = int(row['volume']) if row['volume'] else 0.0
            # 成交额（单位：人民币元）
            amount = float(row['amount']) if row['amount'] else 0.0
            # 复权状态(1：后复权， 2：前复权，3：不复权）
            adjust_flag = int(row['adjustflag']) if row['adjustflag'] else 0.0
            # 换手率
            turn = float(row['turn']) if row['turn'] else 0.0
            # 交易状态(1：正常交易 0：停牌）
            trade_status = int(row['tradestatus']) if row['tradestatus'] else 0.0
            # 涨跌幅（百分比）
            pct_chg = float(row['pctChg']) if row['pctChg'] else 0.0
            # 滚动市盈率
            pe_ttm = float(row['peTTM']) if row['peTTM'] else 0.0
            # 市净率
            pb_mrq = float(row['pbMRQ']) if row['pbMRQ'] else 0.0
            # 滚动市销率
            ps_ttm = float(row['psTTM']) if row['psTTM'] else 0.0
            # 滚动市现率
            pcf_ncf_ttm = float(row['pcfNcfTTM']) if row['pcfNcfTTM'] else 0.0
            # 是否ST股，1是，0否
            is_st = int(row['isST']) if row['isST'] else 0.0

            stockKData = StockKData(id=uuid_util.generate_uuid(), date=date, code=code, open_price=open_price,
                                    high_price=high_price,
                                    low_price=low_price, close_price=close_price,
                                    preclose_price=preclose_price, volume=volume, amount=amount,
                                    adjust_flag=adjust_flag,
                                    turn=turn, trade_status=trade_status, pct_chg=pct_chg, pe_ttm=pe_ttm
                                    , pb_mrq=pb_mrq, ps_ttm=ps_ttm, pcf_ncf_ttm=pcf_ncf_ttm, is_st=is_st)
            return stockKData
    except Exception as e:
        Logfactory.logger.exception("build_stock_kdata exception and error message is %s", e)
        raise e


"""
获取历史A股K线数据：query_history_k_data_plus()
调用 baostock 获取历史A股K线数据：query_history_k_data_plus()
@see http://baostock.com/baostock/index.php/A%E8%82%A1K%E7%BA%BF%E6%95%B0%E6%8D%AE
获取沪深A股历史K线数据
分钟线指标：date,time,code,open,high,low,close,volume,amount,adjustflag
周月线指标：date,code,open,high,low,close,volume,amount,adjustflag,turn,pctChg
stock_code：股票编码
start_date：开始时间，包括
end_date：结束时间，包括
"""


def get_stock_k_data(stock_code, start_date, end_date):
    try:
        # 定义需要获取的列
        fields = (
            "date, code, open, high, low, close, preclose, volume, amount, adjustflag, turn, tradestatus, pctChg, "
            "peTTM, pbMRQ, psTTM, pcfNcfTTM, isST")

        rs = bs.query_history_k_data_plus(stock_code, fields, start_date=start_date, end_date=end_date, frequency="d",
                                          adjustflag="3")

        if rs.error_code != '0':
            Logfactory.logger.error(
                "query_history_k_data_plus failed . code=%s start_date=%s end_date=%s and respond error_code=%s "
                "error_msg=%s",
                stock_code, start_date, end_date, rs.error_code, rs.error_msg)
        else:
            Logfactory.logger.info(
                "query_history_k_data_plus success . code=%s start_date=%s end_date=%s and respond error_code=%s "
                "error_msg=%s",
                stock_code, start_date, end_date, rs.error_code, rs.error_msg)

            data_list = []
            while rs.next():
                # 获取一条记录，将记录合并在一起
                data_list.append(rs.get_row_data())
            result = pd.DataFrame(data_list, columns=rs.fields)

            # 获取指定股票的指定时间范围的 StockKData 集合
            stock_kdata_list = build_stock_kdata_list(result)

            # 保存到数据库表t_stock_kdata
            stock_kdata_dao.save_all(stock_kdata_list)

            # result.to_csv("history_A_stock_k_data.csv", index=False)
    except Exception as e:
        Logfactory.logger.error("get_stock_k_data xception and stock_code=%s, start_date=%s, end_date=%s",e, stock_code, start_date, end_date)
        raise e


"""
数据保存为本地csv文件
"""


def save_stock_kdata_to_file(stock_code, result: pd.DataFrame):
    if result is None:
        return
    result.to_csv(stock_code + "history_A_stock_k_data.csv", index=False)


def print_dataframe(data_frame: pd.DataFrame):
    if data_frame is not None:
        for index, row in data_frame.iterrows():
            Logfactory.logger.info("%s %s %s", row['date'], row['code'], row['open'])


def get_stock_sh_kdata():
    # 登陆系统
    # 显示登陆返回信息
    # print('login respond error_code:' + lg.error_code)
    # print('login respond  error_msg:' + lg.error_msg)

    # 获取指数(综合指数、规模指数、一级行业指数、二级行业指数、策略指数、成长指数、价值指数、主题指数)K线数据
    # 综合指数，例如：sh.000001 上证指数，sz.399106 深证综指 等；
    # 规模指数，例如：sh.000016 上证50，sh.000300 沪深300，sh.000905 中证500，sz.399001 深证成指等；
    # 一级行业指数，例如：sh.000037 上证医药，sz.399433 国证交运 等；
    # 二级行业指数，例如：sh.000952 300地产，sz.399951 300银行 等；
    # 策略指数，例如：sh.000050 50等权，sh.000982 500等权 等；
    # 成长指数，例如：sz.399376 小盘成长 等；
    # 价值指数，例如：sh.000029 180价值 等；
    # 主题指数，例如：sh.000015 红利指数，sh.000063 上证周期 等；

    # 详细指标参数，参见“历史行情指标参数”章节；“周月线”参数与“日线”参数不同。
    # 周月线指标：date,code,open,high,low,close,volume,amount,adjustflag,turn,pctChg
    rs = bs.query_history_k_data_plus("sh.000001",
                                      "date,code,open,high,low,close,preclose,volume,amount,pctChg",
                                      start_date='2023-03-03', end_date='2023-03-04', frequency="d")
    print('query_history_k_data_plus respond error_code:' + rs.error_code)
    print('query_history_k_data_plus respond  error_msg:' + rs.error_msg)

    # 打印结果集
    data_list = []
    while (rs.error_code == '0') & rs.next():
        # 获取一条记录，将记录合并在一起
        data_list.append(rs.get_row_data())
    result = pd.DataFrame(data_list, columns=rs.fields)
    # 结果集输出到csv文件
    result.to_csv("history_Index_k_data.csv", index=False)
    print(result)
    # 登出系统


def dataframe_insert_table(data_frame, table_name, fields):
    # 创建数据库连接
    # conn = psycopg2.connect(
    #    database="db_stock_analysis",
    #    user="postgres",
    #    password="123456",
    #    host="127.0.0.1",
    #    port="5432",
    # )

    # 设置 PostgreSQL 数据库连接参数
    db_params = {
        'host': '127.0.0.1',
        'database': 'db_stock_analysis',
        'user': 'postgres',
        'password': '123456',
        'port': 5432
    }

    # 使用 psycopg2 创建连接
    conn = psycopg2.connect(**db_params)

    # pgsqlUtil = postgresql_util.PGSQLUtil(host="127.0.0.1", user="postgres", password="123456", database="db_stock_analysis")
    # conn = pgsqlUtil.get_conn()

    # 创建一个 SQLAlchemy 引擎
    engine = create_engine(
        f"postgresql+psycopg2://{db_params['user']}:{db_params['password']}@{db_params['host']}/{db_params['database']}")

    data_frame.to_sql(table_name, engine, if_exists="replace")

    conn.close()
